# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
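"""Patched versions of CLIPViTModel's __init__ and forward; the MPATCH blocks add
context-parallel and sequence-parallel support."""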

from typing import Optional, Union

import torch

try:
    import transformer_engine  # pylint: disable=unused-import

    from megatron.core.extensions.transformer_engine import TENorm

    NORM_IMPL = TENorm
except ImportError:
    NORM_IMPL = torch.nn.LayerNorm

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.models.vision.clip_vit_model import CLIPViTModel
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.tensor_parallel.mappings import (
    gather_from_sequence_parallel_region,
    scatter_to_sequence_parallel_region,
)
from megatron.core.utils import get_batch_on_this_cp_rank


def clip_vit_init(
    self,
    transformer_config: TransformerConfig,
    transformer_layer_spec: ModuleSpec,
    ln_pre_impl: Union[ModuleSpec, type] = NORM_IMPL,
    ln_post_impl: Union[ModuleSpec, type] = NORM_IMPL,
    add_class_token: bool = True,
    class_token_len: int = 1,
    patch_dim: int = 14,
    img_h: int = 336,
    img_w: int = 336,
    model_subtype: str = "clip",
) -> None:
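    """Patched constructor for CLIPViTModel.

    Args:
        transformer_config (TransformerConfig): Transformer config.
        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
        ln_pre_impl (ModuleSpec or type): Norm applied before the transformer (defaults to
            TENorm when Transformer Engine is available, torch.nn.LayerNorm otherwise).
        ln_post_impl (ModuleSpec or type): Norm applied after the transformer (same default).
        add_class_token (bool): Prepend a learnable class token to the patch sequence.
        class_token_len (int): Class token length.
        patch_dim (int): Image patch size.
        img_h (int): Input image height.
        img_w (int): Input image width.
        model_subtype (str): One of "clip", "siglip" or "internvit".
    """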

    error_msg = f"CLIPViTModel model subtype {model_subtype} is not supported."
    assert model_subtype in ["clip", "siglip", "internvit"], error_msg

    if model_subtype == "siglip":
        assert class_token_len == 0, "SigLIP does not support class tokens."
        assert not add_class_token, "SigLIP does not support class tokens."

    super(CLIPViTModel, self).__init__(config=transformer_config)

    if has_config_logger_enabled(transformer_config):
        log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)

    self.class_token_len = class_token_len
    self.visual_hidden_size = transformer_config.hidden_size
    self.patch_dim = patch_dim
    self.img_h = img_h
    self.img_w = img_w

    assert self.img_h % self.patch_dim == 0
    assert self.img_w % self.patch_dim == 0
    self.num_patches_per_dim_h = self.img_h // self.patch_dim
    self.num_patches_per_dim_w = self.img_w // self.patch_dim
    self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w

    self.add_class_token = add_class_token

    self.seq_length = self.num_patches + (self.class_token_len if self.add_class_token else 0)

    self.ln_pre = None
    self.ln_post = None
    if model_subtype == "clip":
        self.ln_pre = build_module(
            ln_pre_impl,
            config=transformer_config,
            hidden_size=self.visual_hidden_size,
            eps=transformer_config.layernorm_epsilon,
        )
        conv_bias = False
        padding = 0
    elif model_subtype == "siglip":
        self.ln_post = build_module(
            ln_post_impl,
            config=transformer_config,
            hidden_size=self.visual_hidden_size,
            eps=transformer_config.layernorm_epsilon,
        )
        conv_bias = True
        padding = "valid"
    elif model_subtype == "internvit":
        conv_bias = True
        padding = 0
    else:
        raise ValueError(f"unsupported vision model type {model_subtype}")

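    # Patch embedding: a strided convolution maps each patch_dim x patch_dim patch of the
    # 3-channel input image to a visual_hidden_size vector.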
    self.conv1 = torch.nn.Conv2d(
        in_channels=3,
        out_channels=self.visual_hidden_size,
        kernel_size=self.patch_dim,
        stride=self.patch_dim,
        bias=conv_bias,
        padding=padding,
    )

    self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
    # MPATCH BEGINS
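    # get_batch_on_this_cp_rank expects a batch dict, so wrap the position ids under a dummy
    # key; only the slice belonging to this context-parallel rank is kept.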
    self.position_ids = get_batch_on_this_cp_rank({"key": self.position_ids})["key"]
    # MPATCH ENDS

    self.position_embeddings = torch.nn.Embedding(self.seq_length, self.visual_hidden_size)

    if self.add_class_token:
        self.class_token = torch.nn.Parameter(
            torch.randn(1, self.class_token_len, self.visual_hidden_size)
        )

    self.model_type = ModelType.encoder_or_decoder

    # Transformer layers.
    # TODO: Make pre_process and post_process configurable.
    # NOTE: a final layer norm and/or linear layer in some implementations are omitted here.
    # They can be added separately where needed.
    self.decoder = TransformerBlock(
        config=transformer_config,
        spec=transformer_layer_spec,
        pre_process=True,
        post_process=False,
    )


def clip_vit_forward(
    self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Forward function of the CLIP ViT Model. This function passes the input tensors
    through the embedding layer and then the transformer.

    Args:
        x (torch.Tensor): input image batch of shape [batch, 3, img_h, img_w].
        attention_mask (torch.Tensor, optional): Boolean attention mask to use. Defaults to None.

    Returns:
        x (torch.Tensor): output after final transformer block of shape [b, s, h].
    """
    x = self.conv1(x)  # shape = [batch, hidden_size, grid, grid]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [batch, hidden_size, grid ** 2]
    x = x.permute(0, 2, 1)  # [batch, grid ** 2, hidden_size]

    if self.add_class_token:
        class_token = self.class_token.expand(
            x.shape[0], -1, -1
        )  # [batch, class_token_len, hidden_size]
        x = torch.cat(
            [class_token, x], dim=1
        )  # [batch, grid ** 2 + class_token_len, hidden_size]

    # MPATCH BEGINS
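    # x is expected to arrive already sharded along the sequence dimension across
    # context-parallel ranks; check that the local length is consistent with the full one.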
    assert x.shape[1] * self.config.context_parallel_size == self.seq_length, \
        f"{x.shape[1] * self.config.context_parallel_size} != {self.seq_length}"
    # MPATCH ENDS
    x = x + self.position_embeddings(self.position_ids)
    if self.ln_pre:
        x = self.ln_pre(x)
    x = x.permute(1, 0, 2)  # [b, s, h] -> [s, b, h]
    # `permute` can make the tensor non-contiguous, breaking pipelining.
    x = x.contiguous()

    # MPATCH BEGINS
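    # With sequence parallelism enabled, split the sequence across tensor-parallel ranks
    # before the transformer block and gather it back afterwards.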
    if self.config.sequence_parallel:
        x = scatter_to_sequence_parallel_region(x)
    x = self.decoder(x, attention_mask)
    if self.config.sequence_parallel:
        x = gather_from_sequence_parallel_region(x, tensor_parallel_output_grad=False)
    # MPATCH ENDS

    x = x.permute(1, 0, 2)  # [s, b, h] -> [b, s, h]
    x = x.contiguous()
    if self.ln_post:
        x = self.ln_post(x)
    return x
